查看原文
其他

科普一下那些客户端上的 AI 库

fundroid AndroidPub 2022-07-13
有人说人工智能会是继互联网之后的下一次工业革命,不可否认,大到汽车、小到手表,AI 技术已经广泛应用在我们周围,随便一个 App 都试图跟 AI 发生点关系以证明自己的与时俱进。

AI 的普及为移动端开发带来基于和挑战。挑战在于移动端技术相对于 AI 等新兴技术正被逐渐边缘化;机遇在于移动设备正逐渐成为 AI 技术新的载体,让用户更快捷地享受其带来的便利。

边缘计算(Edge AI)

“模型” 是对一组训练数据应用机器学习算法而得到的结果。使用模型对一些输入的数据进行预测的过程叫“推理”。有很多依靠编写代码仅能低效甚至很难完成的任务,使用模型推理能更好地完成。例如,可以训练模型来归类照片,或者识别照片内的特定对象等。

长久以来,模型推理大多运行在服务端,客户端只作为结果展示的载体。但随着移动端的硬件水平的提升,一些深度学习的数据模型可以以二进制形式下载到手机上进行推理,实施得到 AI 计算的结果,“边缘计算”的概念也由此诞生。

边缘计算有以下好处

  • 数据本地化,解决云端存储及隐私问题;
  • 计算本地化,解决云端计算过载问题;
  • 低通信成本,解决交互和体验问题;
  • 去中心化计算,故障规避与极致个性化。

移动端的机器学习库

边缘计算需要借助移动端的机器学习(Machine Learning,后文简称 ML)库实现。

移动端 ML 库的主要工作在于模型下发和移动端的模型推理,由于移动端算力相对较弱,因此模型文件不宜过大,需要进行一些裁剪和压缩。在人工智能技术越来越普及的今天,作为客户端研发也需要了解一些常见的机器学习库。目前主流的库有 Tensorflow LitePyTorch MobileMediaPipeFirebase ML Kit 等,接下来就这些技术做一个简单介绍,帮大家扩大技术视野。

TensorFlow Lite

https://www.tensorflow.org/lite/guide

TensorFlow Lite 是将 TensorFlow 用于移动设备和嵌入式设备的轻量级解决方案,可以在 Android、iOS 以及其他嵌入式系统上使用。 借由 TensorFlow Lite Converter 将模型经过压缩转换成.tflite 格式。TFLite Converter 提供 Python 和 CLI 工具,推荐使用 Python API。

下面代码是通过 TFLite Converter 将 tensorflowkeras 的模型文件转换为 tflite 格式

import tensorflow as tf

# 转换 saved_model
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()
with open('model.tflite''wb'as f:
    f.write(tflite_model)

# 转换 keras_model
keras_model = tf.keras.models.load_model(filepath)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
keras_tflite_model = converter.convert()
with open('keras_model.tflite''wb'as f:
    f.write(keras_tflite_model)

转换后的模型文件可以下载到手机上使用。 以 Android 为例,代码如下:

class TFLiteActivity : AppCompatActivity() {
    /* 模型下载 */
    private fun initializeTFLite(device: String = "NNAPI", numThreads: Int = 4) {
        val delegate = when (device) {
            "NNAPI" -> NnApiDelegate()
            "GPU" -> GpuDelegate()
            "CPU" -> "" }
        if (delegate != "") tfliteOptions.addDelegate(delegate)

        tfliteOptions.setNumThreads(numThreads)
        tfliteModel = FileUtil.loadMappedFile(this, tflite_model_path)
        tfliteInterpreter = Interpreter(tfliteModel, tfliteOptions)
        inputImageBuffer = TensorImage(tfliteInterpreter.getInputTensor(0).dataType())
        outputProbabilityBuffer = TensorBuffer.createFixedSize(
            tfliteInterpreter.getOutputTensor(0).shape(),
            tfliteInterpreter.getInputTensor(0).dataType())

        probabilityProcessor = TensorProcessor
            .Builder()
            .add(NormalizeOp(0.0f1.0f))
            .build()
    }

    /* 处理处理 */
    @WorkerThread
    override fun analyzeImage(image: ImageProxy, rotationDegrees: Int): Map<String, Float> {
        val bitmap = Utils.imageToBitmap(image)
        val cropSize = Math.min(bitmap.width, bitmap.height)
        inputImageBuffer.load(bitmap)
        val inputImage = ImageProcessor
            .Builder()
            .add(ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(ResizeOp(224224, ResizeMethod.NEAREST_NEIGHBOR))
            .add(NormalizeOp(127.5f127.5f))
            .build()
            .process(inputImageBuffer)

        tfliteInterpreter.run(inputImage!!.buffer, outputProbabilityBuffer.buffer.rewind())
        val labeledProbability: Map<String, Float> = TensorLabel(
            labelsList, probabilityProcessor.process(outputProbabilityBuffer)
        ).mapWithFloatValue
        return labeledProbability
    }
}

上面的代码通俗易懂,因为 TFLite 的 API 非常易用,即使没有太多的移动端开发经验也能轻松驾驭。

PyTorch Mobile

https://pytorch.org/mobile/home/

Facebook 于 19 年底发布了 PyTorch Mobile 作为 PyTorch 的移动端解决方案。PyTorch Mobile 可以将 Pytorchscript 的模型进行 JIT 编译,得到 .pt 格式的文件供移动端使用。20年,PyTorch Developer Day 宣布开始支持 Android 的 NNAPI 和 iOS 的 MetalAPI。

PyTorch 中使用 torch.jit.trace 处理模型转换

import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(13224224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

经转换的模型可以在 Android 和 iOS 中加载。 以 Android 为例,代码如下:

class PyTorchActivity : AppCompatActivity() {
    /* 模型下载 */
    private fun initializePyTorch() {
        val pytorchModule = Module.load(Utils.assetFilePath(
            this,
            pytorch_mobile_model_path))
        val mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224)
        val mInputTensor = Tensor.fromBlob(
            mInputTensorBuffer,
            longArrayOf(13224L224L)
        )
    }

    /* 模型处理 */
    @WorkerThread
    override fun analyzeImage(image: ImageProxy, rotationDegrees: Int): Map<String, Float> {
        TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
            image.image,
            rotationDegrees,
            224,
            224,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB,
            mInputTensorBuffer,
            0
        )
        val outputModule = pytorchModule.forward(IValue.from(mInputTensor)).toTensor()
        val scores = outputModule.dataAsFloatArray
        val labeledProbability: MutableMap<String, Float> = mutableMapOf()
        for (i in 0 until labelsList.size - 1) {
            labeledProbability[labelsList[i + 1]] = score[i]
        }
        return labeledProbability
    }
}

MediaPipe

https://google.github.io/mediapipe/

MediaPipe 与 Tensorflow Lite 和 PyTorch Mobile 不同,不是从已有的深度学习库派生出来的。MediaPipe 专注于计算机视觉和多媒体处理的 ML管道框架,在2019年6月举行的 CVPR 大会,MeidaPipe 正式开源,自那以后,谷歌陆续发布了一系列的 ML 管道示例。MediaPipe 为 Android、iOS 等多平台提供了人脸、物体检测、动作捕捉等能力。

MediaPipe 图形库可以通过 bazel 编译成供 Androd端使用的 .aar 或者 iOS 的 .ipa

# MediaPipe graph that performs face mesh with TensorFlow Lite on GPU.

# GPU buffer. (GpuBuffer)
input_stream: "input_video"

# Output image with rendered results. (GpuBuffer)
output_stream: "output_video"
# Detected faces. (std::vector<Detection>)
output_stream: "face_detections"

# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for downstream nodes
# (calculators and subgraphs) in the graph to finish their tasks before it
# passes through another image. All images that come in while waiting are
# dropped, limiting the number of in-flight images in most part of the graph to
# 1. This prevents the downstream nodes from queuing up incoming images and data
# excessively, which leads to increased latency and memory usage, unwanted in
# real-time mobile applications. It also eliminates unnecessarily computation,
# e.g., the output produced by a node may get dropped downstream if the
# subsequent nodes are still busy processing previous inputs.
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "input_video"
  input_stream: "FINISHED:output_video"
  input_stream_info: {
    tag_index: "FINISHED"
    back_edge: true
  }
  output_stream: "throttled_input_video"
}

# Subgraph that detects faces.
node {
  calculator: "FaceDetectionFrontGpu"
  input_stream: "IMAGE:throttled_input_video"
  output_stream: "DETECTIONS:face_detections"
}

# Converts the detections to drawing primitives for annotation overlay.
node {
  calculator: "DetectionsToRenderDataCalculator"
  input_stream: "DETECTIONS:face_detections"
  output_stream: "RENDER_DATA:render_data"
  node_options: {
    [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
      thickness: 4.0
      color { r: 255 g: 0 b: 0 }
    }
  }
}

# Draws annotations and overlays them on top of the input images.
node {
  calculator: "AnnotationOverlayCalculator"
  input_stream: "IMAGE_GPU:throttled_input_video"
  input_stream: "render_data"
  output_stream: "IMAGE_GPU:output_video"
}

上面是人脸识别的处理过程, 经历了图像输入、人脸检测、特征点绘制等一系列过程,这也正好是 MediaPipe 的特点,不只是简单的模型推理,而是能够将输入、前处理、推理、后处理、输出等一系列流程进行组合编排。

Firebase ML Kit

https://firebase.google.com/docs/ml

Firebase 是基于 Google Mobile Service 的移动开放平台,为移动开发者提供了 APM、埋点等功能,Firebase ML Kit 是 Firebase 提供的面向移动端的机器学习库, 为 Android/iOS 平台提供模型的分发、推理、学习、日志收集等能力,目前只支持 TFLite 格式的模型。

例如使用 ML Kit 识别图片中的物体,代码如下:

private class ObjectDetection : ImageAnalysis.Analyzer {
    val options = FirebaseVisionObjectDetectorOptions.Builder()
        .setDetectorMode(FirebaseVisionObjectDetectorOptions.STREAM_MODE)
        .enableClassification()
        .build()
    val objectDetector = FirebaseVision.getInstance().getOnDeviceObjectDetector(options)

    private fun degreesToFirebaseRotation(degrees: Int): Int = when(degrees) {
        0 -> FirebaseVisionImageMetadata.ROTATION_0
        90 -> FirebaseVisionImageMetadata.ROTATION_90
        180 -> FirebaseVisionImageMetadata.ROTATION_180
        270 -> FirebaseVisionImageMetadata.ROTATION_270
        else -> throw Exception("Rotation must be 0, 90, 180, or 270.")
    }

    override fun analyze(imageProxy: ImageProxy?, degrees: Int) {
        val mediaImage = imageProxy?.image
        val imageRotation = degreesToFirebaseRotation(degrees)
        if (mediaImage != null) {
            val image = FirebaseVisionImage.fromMediaImage(mediaImage, imageRotation)
            objectDetector.processImage(image)
                    .addOnSuccessListener { detectedObjects ->
                        for (obj in detectedObjects) {
                            val id = obj.trackingId
                            val bounds = obj.boundingBox
                            val category = obj.classificationCategory
                            val confidence = obj.classificationConfidence
                            // Do Something
                        }
                    }
                    .addOnFailureListener { e ->
                        // Do Something
                    }
        }
    }
}

上面介绍的 ML 库都需要下发推理引擎到移动端,其实 Android、iOS 也有自带的推理引擎

iOS (Core ML)

https://developer.apple.com/cn/documentation/coreml/ iOS提供了

iOS 提供 CoreML 可以将各种机器学习模型集成到应用中并进行推理。CoreML 不仅支持 TFLite 格式的模型,还支持 ONNX、Pytorch、XGBoost、Scikit-learn 等多种格式。这些模型通过 coremltools 转换为 CoreML 专用格式后加载到本地。iPhone 搭载了专用的神经网络处理器,可以低功耗地进行模型推理。想用 iOS 进行机器学习的话,CoreML 是一个好选择。

Android (NNAPI)

https://developer.android.com/ndk/guides/neuralnetworks

Android 端提供了 NNAPI(Android Neural Networks API)用于模型推理。 NNAPI 是 Android 8.1(API等级27)以后提供的专门处理机械学习的 Native 库。NNAPI 会根据手机当前的硬件性能、负荷状况等,将处理跑在特定设备上(GPU、DSP、专用处理器),当然也可以统一交由 CPU 执行。

Web端

Tensorflow.js

https://www.tensorflow.org/js

浏览器也可以进行机器学习和推理。Web浏览器中进行 AI 计算的主要语言是 Javascript、 使用的 ML 库是 Tensorflow.js 。学习和推理都要在浏览器执行,相对于 Android 、 iOS 性能上不占优,但是通过 WebGL 和 WASM 的辅助,也能满足基本的使用需求。

tf.keras 和 saved model 可以转换为 tensorflow.js 可处理的 json 格式,其中包含了神经网络结构和权重。

整个转换通过 tensorflowjs_converter 工具进行

# saved_model转换
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

# Keras_model変換
tensorflowjs_converter \
    --input_format keras \
    path/to/my_model.h5 \
    path/to/tfjs_target_dir


tensorflow.js 可以嵌入到 html 中使用,代码如下:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>

<img id="img" src="cat.jpg"></img>

<script>
  const img = document.getElementById('img');
  // Load the model.
  mobilenet.load().then(model => {
    // Classify the image.
    model.classify(img).then(predictions => {
      console.log('Predictions: ');
      console.log(predictions);
    });
  });
</script>


当然也可以在js中使用

import * as tf from "@tensorflow/tfjs";

import { IMAGENET_CLASSES } from "./imagenet_classes";

const MOBILENET_MODEL_PATH =
  "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";

const IMAGE_SIZE = 224;
const TOPK_PREDICTIONS = 10;

let mobilenet;
const mobilenetDemo = async () => {
  status("Loading model...");

  mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  mobilenet.predict(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])).dispose();

  status("");

  const catElement = document.getElementById("cat");
  if (catElement.complete && catElement.naturalHeight !== 0) {
    predict(catElement);
    catElement.style.display = "";
  } else {
    catElement.onload = () => {
      predict(catElement);
      catElement.style.display = "";
    };
  }
};

async function predict(imgElement) {
  status("Predicting...");
  const logits = tf.tidy(() => {
    const img = tf.browser.fromPixels(imgElement).toFloat();

    const offset = tf.scalar(127.5);

    const normalized = img.sub(offset).div(offset);

    const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);

    return mobilenet.predict(batched);
  });
}

ml5.js

https://learn.ml5js.org/#/

简单来说,ml5.js 是对 tensorflow.js 的封装,它的 API 更加简单易懂,适合 ML 学者使用。ml5.js 提供了 API 用于在图像、语言、声音等媒体中进行识别和变换等 AI 能力,而且 API 风格更加贴近 Javascript 的习惯。

下面是使用 ml5.js 进行推理的代码示例:

let classifier;

let img;

function preload() {
  classifier = ml5.imageClassifier("MobileNet");
  img = loadImage("images/bird.png");
}

function setup() {
  createCanvas(400400);
  classifier.classify(img, gotResult);
  image(img, 00);
}

function gotResult(error, results) {
  if (error) {
    console.error(error);
  } else {
    console.log(results);
    createDiv(`Label: ${results[0].label}`);
    createDiv(`Confidence: ${nf(results[0].confidence, 0, 2)}`);
  }
}

模型压缩

前文提到过,模型压缩更利于端上进行推理,最后介绍几种常见的模型压缩方法:

  • 参数量化(Parameter Quantization)
  • 网络剪枝(Network Pruning)
  • 知识蒸馏(Knowledge Distillation)

量化(Quantization)

量化就是使用更少的bits来表示一个参数。例如在创建模型的时候使用32位浮点数进行学习,在进行推理之前,将其转化为16位、8位、甚至1位(boolean)。通过减少模型体积降低计算量。

当然量化会导致精度的劣化,计算结果或产生偏差,这也是一种面向性能的折中。

剪枝(Pruning)

神经网络中的一些冗余的权重和神经元是可以被剪枝的,因为这些权重较低或者神经元的输出大多数时候为零。通过删除这些内容可以减轻模型的重量。另外通过共享的手段,在多个节点之间共享权重,也可以减少模型的容量。

跟量化一样,删除一些权重可能导致精度劣化。

蒸馏(Distillation)

蒸馏是通过在学习方法上下功夫来提高压缩模型精度的方法。

蒸馏现以高精度的大容量模型进行学习,以大容量模型的推理结果作为特征参考,参与轻量模型的计算中,以提高准确率。由于大容量模型的推理结果是一个标签的概率分布,所以轻量模型从标签的概率中学习各个数据标签的相似性。

例如,有一张猫的画像,分布在猫60%、狗30%、兔子10%的情况下,猫的图像是猫:狗:兔子=6:3:1的特征。

总结

从广义上说,边缘计算是个宽泛的概念,本文介绍的边缘 AI(Edge AI)只是其中的一个方向。边缘 AI 促使了人工智能从云智能向端智能的延伸。在云智能中,网络是智能决策的底限成本,基于云智能的业务需要应对QPS、网络延迟等问题;而在端智能中,没有网络环境的制约,这意味着更高的决策频率。也正因如此,移动端仍长期具有旺盛的生命力,并将以新的姿态走得更远。



(完)



推荐阅读
看了这篇文章,终于搞懂了 Android 存储
Jetpack MVVM七宗罪 之三 :在 onViewCreated 中请求数据
Jetpack MVVM七宗罪 之二 :使用 luanchWhenX 启动协程
Jetpack MVVM七宗罪 之一 :还在使用 Fragment 作为 LifecycleOwner ?


加好友拉你进群,技术干货聊不停


↓关注公众号↓↓添加微信交流↓



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存